

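% inputs:
% regExpindicator - 1x2 cell array: regExpindicator{1} is the string that
% precedes the movie # for all the movies in the directory; files whose
% names contain regExpindicator{2} are skipped
% piaTracking (optional) - previously compiled results; if empty or omitted,
% they are compiled from the *groupedResults.mat files in the current folder
% session (optional) - imaging session #; prompted for if omitted
% objType (optional) - objective type used for the pixel-to-micron
% conversion when a movie has none stored (default '16x')
% prestimTime / poststimTime (optional) - seconds kept before stimulus onset
% and after the stimulus (defaults 4.5 and 9.5)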
%
% outputs:
% trajectoriesDDD - each column is a time point (sampled at fps) aligned to
% the start of the stimulus, cropped to prestimTime seconds before onset
% through poststimTime seconds after the stimulus. Despite the "deltaD/D"
% name, values are the diameter normalized to baseline (D/D0, 1 = baseline);
% the baseline D0 is the median diameter over the 2 second window prior
% to stim presentation.
% Each row is a separate individual stimulation presentation
%
% trajectoriesD - same as trajectoriesDDD but with just the raw diameter in
% microns
%
% trajectoryinfo - contains information about what happened in that
% particular stim; the rows correspond to the rows of trajectories.
% first column is the # of the movie
% second column is the vessel ID corresponding to that trajectory
% Third-fifth columns are the position of the vessel in the window (x,y,z
% in microns)
% sixth and seventh columns are the STARTING position of the stim on the screen
% (x,y as a fraction of the screen), eighth is the movie ID (1 if only
% using one kind of movie)
% ninth and tenth is ENDING position of the stim (x,y). 
% eleventh is the size of the stim (in radians if circles; pixels if
% square)
% twelfth is the duration of the stimulus in seconds
% if there are multiple stims presented simultaneously, then the columns
% will continue with the positions (i.e. if there are two stims, then 6th
% column is x-position of the first stim box, 7th column is y-position of
% the first stim box, 8th is the movie played in first stim,...
% 13th column is x-position of the second stim box, 14th
% column is y-position of the second stim box etc.)
% 
% 13th (+7*(numStims-1)) column is the global vessel ID which is a unique identifier for
% each vessel imaged in the whole window
% 14th (+7*(numStims-1)) column is the session#
%
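% the above arrays were intended to be the results of averaging the stims in
% a given movie (i.e. if you present a particular stim 3 times in one movie,
% then you average over those 3 presentations). NOTE: in this version the
% within-movie averaging (averageWithinMovie) is not called, so each row is
% still a single presentation (kymographs of the same vessel are combined by
% their median). The structure "rawResults" holds the trajectories, the
% trajectoryinfo table, and the running and pupil data before the
% running/pupil exclusion and the prestim/poststim cropping.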
%
% updates from v5:
% rewriting to include ability to work with moving stimuli; note this
% changes the size of trajectoryinfo and associated tables (see above)

function [trajectoriesDDD,trajectoryinfo,piaTracking,trajectoriesD,running,rawResults,trajectoriesGCaMP,pupilDisplacement,pupilRadius,stiminfotable] = poolStimData_v5_3(regExpindicator,piaTracking,session,objType,prestimTime,poststimTime)
% poolStimData_v5_3 pools per-presentation vessel diameter trajectories,
% running and pupil data from the *groupedResults.mat files in the current
% folder (or from a pre-compiled piaTracking struct array), aligns them to
% stimulus onset, blanks presentations with too much running / pupil
% dilation, and crops them to the requested prestim/poststim window.
% See the header comment above for the layout of the outputs.
%
% regExpindicator - 1x2 cell; a file is used only if its name contains
%                   regExpindicator{1} and does NOT contain
%                   regExpindicator{2} ({2} is also passed to
%                   getSupportFiles as the forbidden string)
% piaTracking     - optional struct array of compiled results; when empty
%                   or omitted it is built from *groupedResults.mat
% session         - written to the last info column; prompted for if omitted
% objType         - objective used for pixel->micron conversion when a
%                   movie has no 'objtype' field (default '16x')
% prestimTime     - seconds kept before stim onset (default 4.5; lowered to
%                   results.preStimTime-0.5 when a loaded movie has less)
% poststimTime    - seconds kept after the stimulus (default 9.5)

if nargin < 4 || isempty(objType)
    defaultObjType = '16x';
else
    defaultObjType = objType;
end
dowaveletfilter = true;             %option to pre-filter the trajectories before averaging
if nargin < 5 || isempty(prestimTime)
    prestimTime = 4.5;
end
fps = 10;                           % frames per second (seconds*fps = frames below)
if nargin < 6 || isempty(poststimTime)
    poststimTime = 9.5;
end
prestimBehavior = 3;                % s before onset used for the running/pupil screen
poststimBehavior = 3;               % s after the stim used for the running/pupil screen
prestimBaseline = 2;                % s before onset used for the diameter baseline
runningthresh = 300;     % exclude trajectories where mouse is running faster than this threshold
pupilthresh = 100;        % exclude trajectories where pupil is beyond this radius
fn = dir('*groupedResults.mat');

if nargin < 3 || isempty(session)
    session = input('which imaging session is this: ');
end
%% append some metadata into the results structures:
if nargin < 2 || isempty(piaTracking)
    piaTracking = [];
    for n = 1:length(fn)
        % clear per-file variables up front: the 'continue' paths below
        % would otherwise leave them (e.g. globalVesselID) set for the next file
        clear results validtracking globalVesselID
        load(fn(n).name)
        if ~exist('results','var') || ~isfield(results,'kymos')        % this screens out diving vessel movies
            continue;
        end
        if isempty(results.kymos)
            continue;
        end
        if ~isfield(results,'validtracking')
            disp('unable to compile results')
            continue;
        end
        if contains(fn(n).name,regExpindicator{2})
            continue;
        end
        strind = strfind(fn(n).name,regExpindicator{1});
        if isempty(strind)
            continue;
        end

        %get the corresponding stim metadata and eye tracking names; reset
        %first so a failed lookup can never reuse the previous movie's files
        vismetadata = [];
        eyetrackingname = [];
        try
            [vismetadata,eyetrackingname] = getSupportFiles(fn(n).name,results.movname,regExpindicator{2});
        catch
            disp('unable to detect either metadata filename or eye tracking filename');
        end
        if isempty(vismetadata)
            continue;       % stim info cannot be computed without metadata
        end

        try
            [results.eyeTracking,results.eyeTrackingMovname] = getEyeTracking(eyetrackingname);
        catch
            results.eyeTracking = nan(size(results.rawTracking,2)*size(results.rawTracking,3),2);
            results.eyeTrackingMovname = [];
            disp('cannot get eye tracking')
        end
        % pad (leading NaNs) or truncate eye tracking to
        % size(rawTracking,2)*size(rawTracking,3) rows
        nEyeRows = size(results.rawTracking,2)*size(results.rawTracking,3);
        if size(results.eyeTracking,1) < nEyeRows
            results.eyeTracking = [nan(nEyeRows-size(results.eyeTracking,1),2);results.eyeTracking];
        elseif size(results.eyeTracking,1) > nEyeRows
            results.eyeTracking = results.eyeTracking(1:nEyeRows,:);
            disp('eye tracking longer than expected')
        end
        results.running = results.ballTracking;
        fieldsToRemove = {'stimID','ballTracking'};
        results = rmfield(results,fieldsToRemove);
        % -> size(rawTracking,3) x size(rawTracking,2) x [displacement, radius]
        results.eyeTracking = permute(reshape(results.eyeTracking,size(results.rawTracking,2),size(results.rawTracking,3),2),[2,1,3]);
        [results.stiminfo,results.uniqueID,results.stimMovieParams] = getStimInfo_v4_4(results.movname,vismetadata,results);
        if isfield(results,'directory')
            results = rmfield(results,'directory');
        end

        results.directory = cd;
        if exist('globalVesselID','var')
            results.globalVesselID = globalVesselID;        %this value gives a unique identifier to each vessel imaged in the whole window; needs to be consistent across days. use the function "assignVesselID_manual_dev.m" currently to do so (11/20/2020)
        else
            results.globalVesselID = [];
        end
        if ~isfield(results,'roughedges')
            results.roughedges = results.edgepos;
        end
        if isfield(results,'alternateChannelKymos')
            curzoom = getSImetadata(results.movname);pix2mic = 1000/curzoom.Height/curzoom.zoom;
            curzoom = curzoom.zoom;
            ratio = nan(size(results.rawTracking));
            for n2 = 1:size(results.kymos,1)
                for n3 = 1:size(results.kymos,2)
                    ratio(n3,:,n2) = sensorRatioFromTracking(results.kymos{n2,n3},results.alternateChannelKymos{n2,n3},...
                        results.edgepos{n2,n3},pix2mic);
                end
            end
            results.ratio = ratio;
        end
        if ~isfield(results,'objtype')
            results.objtype = defaultObjType;
        end
        if results.preStimTime < prestimTime
            prestimTime = results.preStimTime - .5;
        end

        piaTracking = [piaTracking;results];
    end
end

if isempty(piaTracking)
    trajectoriesD = [];
    trajectoriesDDD = [];
    trajectoryinfo = [];
    running = [];
    rawResults = [];
    trajectoriesGCaMP = [];
    pupilDisplacement = [];
    pupilRadius = [];
    stiminfotable = [];
    return;
end

rawArray = [];      %this stores D/D0 (normalized diameter)
rawRatio = [];      %this stores the ratiometric GCaMP signal
moviename = [];
rawArrayDiameter = [];  %this stores the raw diameter (in microns)
arrayInfo = [];
running = [];
pupilDisplacement = [];
pupilRadius = [];
%%
for n = 1:length(piaTracking)       %iterate through each movie

    % NaN (rather than the previous movie's value) if the zoom lookup fails
    pix2mic = nan;
    try
        meta = getSImetadata(piaTracking(n).movname(1));
        if isfield(piaTracking(n),'objtype')
            pix2mic = pixel2micron_2p(meta.zoom,piaTracking(n).objtype);
        else
            pix2mic = pixel2micron_2p(meta.zoom,defaultObjType);
        end
    catch
        disp('zoom error!')
    end

    %% filter out outlier points:
    temp = nan(size(piaTracking(n).rawTracking));
    for n2 = 1:size(temp,1)
        for n3 = 1:size(temp,3)
            if piaTracking(n).validtracking(n3,n2) == 1
                curtracking = piaTracking(n).edgepos{n3,n2};
                curtracking = abs(diff(curtracking,1,1));
                if sum(isnan(curtracking))> .5*length(curtracking)
                    piaTracking(n).validtracking(n3,n2) = 0;
                else
                    temp(n2,:,n3) = removeOutlierPoints(curtracking);
                end
            end
        end
    end
    rawTracking = temp;

    % NOTE(review): rawRatio is overwritten for every movie, so
    % trajectoriesGCaMP only holds the last movie's ratio (unaveraged) - confirm intended
    if isfield(piaTracking(n),'ratio')
        rawRatio = piaTracking(n).ratio;
    else
        rawRatio = nan;
    end
    vesselID = piaTracking(n).vesselID;
    uniqueVessels = unique(vesselID,'stable');
    vesselPositions = piaTracking(n).vesselPositions;
    up = piaTracking(n).uniqueID;
    if isfield(piaTracking,'globalVesselID')
        globalVesselID = piaTracking(n).globalVesselID;
    else
        globalVesselID = [];
    end
    rn = piaTracking(n).running;
    eyeTrack = piaTracking(n).eyeTracking;
    DDD = cell(size(uniqueVessels));
    D = cell(size(uniqueVessels));
    % average over the kymographs from each individual vessel
    baselineFrames = ((piaTracking(n).preStimTime-prestimBaseline)*fps):(piaTracking(n).preStimTime*fps);
    for n2 = 1:numel(uniqueVessels)     %iterate through each vessel
        ind = find(vesselID == uniqueVessels(n2));
        [a,b] = averageForVessel(rawTracking,ind,baselineFrames,dowaveletfilter);
        b = b*pix2mic;
        DDD{n2} = permute(a,[3,2,1]);D{n2} = permute(b,[3,2,1]);
    end
    % per-trial index into the rows of uniqueID, read from row 2 of
    % stiminfo at the first frame after onset (this used to be row 1, but I
    % think that was a mistake...)
    stimIdx = piaTracking(n).stiminfo(2,(piaTracking(n).preStimTime*fps+1),:);
    stimIdx = permute(stimIdx,[3,1,2]);
    curstimpos = zeros(size(stimIdx,1),size(up,2));
    for n2 = 1:size(up,1)
        ind = find(stimIdx == n2);
        curstimpos(ind,:) = repmat(up(n2,:),numel(ind),1);
    end
    % numel (not size(...,1)): uniqueVessels is a row when vesselID is a row
    for n2 = 1:numel(DDD)
        nTrials = size(D{n2},1);
        % build the info rows first and append data only if they fit, so the
        % rows of trajectories, running, pupil and trajectoryinfo stay aligned
        try
            % note that here you will rely on the third column of stim
            % positions to indicate that you have a moving versus
            % stationary stim and that'll change downstream processing
            temp = [repmat(n,nTrials,1),repmat(uniqueVessels(n2),nTrials,1),...
                repmat(vesselPositions(n2,:),nTrials,1),curstimpos];
        catch
            disp('truncatingstimpositions')
            tempstimpos = curstimpos(1:nTrials,:);
            temp = [repmat(n,nTrials,1),repmat(uniqueVessels(n2),nTrials,1),...
                repmat(vesselPositions(n2,:),nTrials,1),tempstimpos];
        end
        globalID = nan;
        if ~isempty(globalVesselID)
            try
                globalID = globalVesselID(uniqueVessels(n2));
            catch
                warning('poolStimData:globalVesselID', ...
                    'movie %d: no global ID for local vessel %d; using NaN', n, uniqueVessels(n2));
            end
        end
        temp = [temp,repmat(globalID,nTrials,1),repmat(session,nTrials,1)];
        try
            arrayInfo = [arrayInfo;temp];
        catch
            % column count differs from earlier rows (e.g. a different
            % number of stim boxes); skip this vessel entirely
            warning('poolStimData:infoMismatch', ...
                'movie %d, vessel %d: info has %d columns (expected %d); skipping', ...
                n, uniqueVessels(n2), size(temp,2), size(arrayInfo,2));
            continue;
        end
        rawArray = [rawArray;DDD{n2}];
        rawArrayDiameter = [rawArrayDiameter;D{n2}];
        running = [running;rn(1:nTrials,:)];
        pupilRadius = [pupilRadius;eyeTrack(:,:,2)];
        pupilDisplacement = [pupilDisplacement;eyeTrack(:,:,1)];
    end
end

trajectoriesGCaMP = rawRatio;
trajectoriesDDD = rawArray;     %this stores the D/D0 trajectory
trajectoriesD = rawArrayDiameter;   %this stores just the diameter
trajectoryinfo = arrayInfo;
% info columns: 5 movie/vessel columns + 7 per stim box + globalVesselID + session
numstimboxes = round((size(arrayInfo,2)-7)/7);
% convert the array info into a table:
tableheader = {'movieNumber','localVesselID',...
    'vesselX_position','vesselY_position','vesselZ_position'};
tableheader = [tableheader,repmat({'stimStartX_position','stimStartY_position','stimMovieID',...
    'stimEndX_position','stimEndY_position','stimMovieSize','stimMovieDuration'},1,numstimboxes)];
tableheader = [tableheader,{'globalVesselID','sessionNumber'}];
stiminfotable = array2table(trajectoryinfo,'VariableNames',tableheader);


rawResults.trajectoriesDDD = trajectoriesDDD;
rawResults.trajectoriesD = trajectoriesD;
rawResults.trajectoryinfo = stiminfotable;
rawResults.running = running;
rawResults.pupilDisplacement = pupilDisplacement;
rawResults.pupilRadius = pupilRadius;

%% filter out trajectories that have too much running or pupil dilation:
% used as a frame count below (presumably the number of stim-on frames of
% the first movie's first trial - confirm)
stimduration = sum(piaTracking(1).stiminfo(3,:,1));
behaviorframes = ((piaTracking(1).preStimTime-prestimBehavior)*fps):((piaTracking(1).preStimTime+poststimBehavior)*fps+stimduration);
behaviorframes = behaviorframes(behaviorframes>0);
behaviorframes = behaviorframes(behaviorframes <= size(running,2));
avRun = nanmean(running(:,behaviorframes),2);
avPupil = nanmean(pupilRadius(:,behaviorframes),2);
i1 = find(avRun > runningthresh);
i2 = find(avPupil > pupilthresh);
ind = union(i1,i2);
trajectoriesDDD(ind,:) = nan;
trajectoriesD(ind,:) = nan;

%% crop to within the specified prestim/poststim bounds:

validframes = (1+(piaTracking(1).preStimTime-prestimTime)*fps):((piaTracking(1).preStimTime+poststimTime)*fps+stimduration);

trajectoriesDDD = trajectoriesDDD(:,validframes);

trajectoriesD = trajectoriesD(:,validframes);
pupilRadius = pupilRadius(:,validframes);
pupilDisplacement = pupilDisplacement(:,validframes);
if max(validframes)>size(running,2)
    % '<=' keeps the last available running column
    running = running(:,validframes(validframes<=size(running,2)));
    disp('running is too short')
else
    running = running(:,validframes);
end



function [r,movname] = getEyeTracking(filename)
% getEyeTracking loads pupil tracking from a .mat file that stores the
% variables 'center' and 'radius' (and optionally 'movname') and returns
% one row per frame.
%
% r(:,1)  - distance of the pupil center from its median position (same
%           units as 'center' in the file)
% r(:,2)  - first column of 'radius'; frames with radius < 3 are NaN in
%           both columns
% movname - the 'movname' stored in the file, or [] if there is none
%
% Frames with a NaN center are linearly interpolated. If more than 70% of
% frames have no valid center, r is returned as all NaN.
S = load(filename);
if isfield(S,'movname')
    movname = S.movname;
else
    movname = [];
end
center = S.center;
radius = S.radius(:,1);
nf = size(center,1);
vi = find(~isnan(center(:,1)));
blanki = find(isnan(center(:,1)));
if numel(blanki)/nf > .7       %don't bother including movies where more than this threshold is non-valid
    r = nan(nf,2);
    return;
end
medpos = nanmedian(center,1);
center = sqrt((center(:,1)-medpos(1)).^2 + (center(:,2)-medpos(2)).^2);
center(blanki) = interp1(vi,center(vi),blanki);
radius(blanki) = interp1(vi,radius(vi),blanki);
blanki = find(radius<3);
radius(blanki) = nan;
center(blanki) = nan;
r = [center,radius];
return;


function [stimtraj,stimtrajRaw] = averageForVessel(vesselTracking,ind,baselineFrames,dowaveletfilter)
% averageForVessel combines all kymographs (rows ind of vesselTracking) of
% one vessel by taking their NaN-ignoring median.
%
% vesselTracking  - diameter array indexed (kymograph, frame, trial);
%                   presumably in pixels (the caller scales to microns)
% ind             - row indices of the kymographs belonging to this vessel
% baselineFrames  - frame indices whose median diameter is the baseline D0
%                   (computed per kymograph and per trial)
% dowaveletfilter - if true, wavelet-denoise every trace first; traces with
%                   >= 25% NaN samples are left unfiltered, traces with
%                   fewer NaNs are gap-filled with resample before denoising
%
% stimtraj    - 1 x frame x trial median of D./D0 (1 == baseline)
% stimtrajRaw - 1 x frame x trial median diameter (same units as input)

if dowaveletfilter
    for n = 1:length(ind)
        for n2 = 1:size(vesselTracking,3)
            cur = vesselTracking(ind(n),:,n2);
            nmissing = sum(isnan(cur));
            if nmissing >= .25*length(cur)
                % too many gaps to fill; also keeps NaNs away from wdenoise
                % (previously exactly 25% missing fell through to wdenoise)
                continue;
            elseif nmissing > 0
                % resample(x,tx) ignores NaN samples and interpolates onto a
                % uniform grid with the same endpoints and length
                cur = resample(cur,1:length(cur));
            end
            vesselTracking(ind(n),:,n2) = wdenoise(cur,3, ...
                'Wavelet', 'sym4', ...
                'DenoisingMethod', 'Bayes', ...
                'ThresholdRule', 'Median', ...
                'NoiseEstimate', 'LevelDependent');
        end
    end
end
baselineD = nanmedian(vesselTracking(ind,baselineFrames,:),2);
stimtrajRaw = vesselTracking(ind,:,:);
stimtraj = stimtrajRaw./baselineD;      % implicit expansion across frames
stimtraj = nanmedian(stimtraj,1);
stimtrajRaw = nanmedian(stimtrajRaw,1);
return;


function [trajAVG,trajinfoAVG] = averageWithinMovie(traj,trajinfo)
% averageWithinMovie collapses the rows of traj whose trajinfo rows agree on
% the key columns 1, 2, 6, 7, 8 and 10 into one NaN-ignoring mean row.
% Groups appear in order of first occurrence, and trajinfoAVG holds the
% full trajinfo row of each group's first member.
% NOTE(review): not called in this file, and the key skips column 9 while
% using column 10 - confirm this matches the current trajinfo layout.
keyCols = [1 2 6 7 8 10];
[~,firstRow,groupOf] = unique(trajinfo(:,keyCols),'rows','stable');
nGroups = numel(firstRow);
trajinfoAVG = double(trajinfo(firstRow,:));
trajAVG = zeros(nGroups,size(traj,2));
for g = 1:nGroups
    members = (groupOf == g);
    trajAVG(g,:) = nanmean(traj(members,:),1);
end
return;



function [metadataname,eyetrackingname] = getSupportFiles(groupedTrackingName,movname2p,forbiddenstring)
% getSupportFiles finds, in the current folder, the stimulus-metadata file
% and the eye-tracking file belonging to one *groupedResults.mat file.
%
% groupedTrackingName - grouped results file name; the two characters that
%                       end two positions before 'grouped' are the repeat
%                       number (e.g. '01' in 'xxx_01_groupedResults.mat')
% movname2p           - cell array of 2p movie names; the first name minus
%                       its last 10 characters is the eye-tracking prefix
% forbiddenstring     - metadata files containing this string are ignored
%
% metadataname    - the last '*Metadata*.mat' file (in dir order) whose
%                   separator-stripped name contains the grouped file's
%                   prefix and whose last number equals the repeat number
% eyetrackingname - first '<prefix>*EyeTracking.mat' match, or [] if none
%
% Raises 'getSupportFiles:noMetadata' when no metadata file matches.

% matches runs of characters that are not whitespace or . , ; : _ -
nonSeparatorRun = '[^ \f\n\r\t\v.,;:_-]*';

groupedPos = strfind(groupedTrackingName,'grouped');
repeatnum = groupedTrackingName((groupedPos-3):(groupedPos-2));

% strip separators, then keep everything before '<repeatnum>grouped'
rt = regexp(groupedTrackingName,nonSeparatorRun,'match'); rt = [rt{:}];
prefixEnd = strfind(rt,[repeatnum,'grouped']);
rt = rt(1:(prefixEnd-1));

metadataname = [];
fn = dir('*Metadata*.mat');
for n = 1:length(fn)
    cur = regexp(fn(n).name,nonSeparatorRun,'match'); cur = [cur{:}];
    if contains(cur,forbiddenstring) || ~contains(cur,rt)
        continue;
    end
    nums = regexp(cur,'\d*','match');
    % a name without any digits cannot carry the repeat number
    if ~isempty(nums) && strcmp(nums{end},repeatnum)
        metadataname = fn(n).name;
    end
end
if isempty(metadataname)
    error('getSupportFiles:noMetadata', ...
        'no metadata file found for %s', groupedTrackingName);
end

eyePrefix = movname2p{1}(1:(length(movname2p{1})-10));
fn = dir([eyePrefix,'*','EyeTracking.mat']);
if isempty(fn)
    eyetrackingname = [];
else
    eyetrackingname = fn(1).name;
end

return;

